{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {},
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W3D4_ReinforcementLearning/student/W3D4_Tutorial3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a> &nbsp; <a href=\"https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/NeoNeuron/professional-workshop-3/master/tutorials/W10_ReinforcementLearning/student/W10_Tutorial3.ipynb\" target=\"_parent\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" alt=\"Open in Kaggle\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Tutorial 3: Learning to Act: Q-Learning\n",
    "\n",
    "__Content creators:__ Marcelo Mattar and Eric DeWitt with help from Byron Galbraith\n",
    "\n",
    "__Content modified:__ Kai Chen"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "\n",
    "# Tutorial Objectives\n",
    "  \n",
    "In this tutorial you will learn how to act in the more realistic setting of sequential decisions, formalized by Markov Decision Processes (MDPs). In a sequential decision problem, the actions executed in one state may not only lead to immediate rewards (as in a bandit problem), but also affect the states experienced next (unlike a bandit problem). Each individual action may therefore affect all future rewards. Thus, making decisions in this setting requires evaluating each action in terms of its expected **cumulative** future reward.\n",
    "\n",
    "We will consider here the example of spatial navigation, where actions (movements) in one state (location) affect the states experienced next, and an agent might need to execute a whole sequence of actions before a reward is obtained.\n",
    "\n",
    "By the end of this tutorial, you will learn\n",
    "* what grid worlds are and how they help in evaluating simple reinforcement learning agents\n",
    "* the basics of the Q-learning algorithm for estimating action values\n",
    "* how the concept of exploration and exploitation, reviewed in the bandit case, also applies to the sequential decision setting"
   ]
  },
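  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "To make *cumulative future reward* concrete: the return from time $t$ is $G_t = r_t + \\gamma r_{t+1} + \\gamma^2 r_{t+2} + \\dots$, where the discount factor $\\gamma$ lies between 0 and 1. The sketch below uses a hypothetical helper, `discounted_return` (not part of the tutorial code), which computes this sum for a list of rewards by working backwards through the episode with $G_t = r_t + \\gamma G_{t+1}$.\n",
    "\n",
    "```python\n",
    "def discounted_return(rewards, gamma):\n",
    "  # accumulate G_t = r_t + gamma * G_{t+1}, starting from the last reward\n",
    "  g = 0.0\n",
    "  for r in reversed(rewards):\n",
    "    g = r + gamma * g\n",
    "  return g\n",
    "\n",
    "rewards = [-1] * 11  # e.g. eleven steps that each cost -1\n",
    "print(discounted_return(rewards, gamma=1.0))  # -11.0: every step counts fully\n",
    "print(discounted_return(rewards, gamma=0.9))  # smaller in magnitude: later steps are discounted\n",
    "```\n",
    "\n",
    "With $\\gamma = 1$, the value used later in this tutorial, the return is simply the sum of all rewards in the episode."
   ]
  },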
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.signal import convolve as conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "#@title Figure Settings\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "nma_style = {\n",
    "    'figure.figsize' : (8, 6),\n",
    "    'figure.autolayout' : True,\n",
    "    'font.size' : 15,\n",
    "    'xtick.labelsize' : 'small',\n",
    "    'ytick.labelsize' : 'small',\n",
    "    'legend.fontsize' : 'small',\n",
    "    'axes.spines.top' : False,\n",
    "    'axes.spines.right' : False,\n",
    "    'xtick.major.size' : 5,\n",
    "    'ytick.major.size' : 5,\n",
    "}\n",
    "for key, value in nma_style.items():\n",
    "    plt.rcParams[key] = value\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "#@title Plotting Functions\n",
    "\n",
    "def plot_state_action_values(env, value, ax=None):\n",
    "  \"\"\"\n",
    "  Generate plot showing value of each action at each state.\n",
    "  \"\"\"\n",
    "  if ax is None:\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "  for a in range(env.n_actions):\n",
    "    ax.plot(range(env.n_states), value[:, a], marker='o', linestyle='--')\n",
    "  ax.set(xlabel='States', ylabel='Values')\n",
    "  ax.legend(['R','U','L','D'], loc='lower right')\n",
    "\n",
    "\n",
    "def plot_quiver_max_action(env, value, ax=None):\n",
    "  \"\"\"\n",
    "  Generate plot showing action of maximum value or maximum probability at\n",
    "    each state (not for n-armed bandit or cheese_world).\n",
    "  \"\"\"\n",
    "  if ax is None:\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "  X = np.tile(np.arange(env.dim_x), [env.dim_y,1]) + 0.5\n",
    "  Y = np.tile(np.arange(env.dim_y)[::-1][:,np.newaxis], [1,env.dim_x]) + 0.5\n",
    "  which_max = np.reshape(value.argmax(axis=1), (env.dim_y,env.dim_x))\n",
    "  which_max = which_max[::-1,:]\n",
    "  U = np.zeros(X.shape)\n",
    "  V = np.zeros(X.shape)\n",
    "  U[which_max == 0] = 1\n",
    "  V[which_max == 1] = 1\n",
    "  U[which_max == 2] = -1\n",
    "  V[which_max == 3] = -1\n",
    "\n",
    "  ax.quiver(X, Y, U, V)\n",
    "  ax.set(\n",
    "      title='Maximum value/probability actions',\n",
    "      xlim=[-0.5, env.dim_x+0.5],\n",
    "      ylim=[-0.5, env.dim_y+0.5],\n",
    "  )\n",
    "  ax.set_xticks(np.linspace(0.5, env.dim_x-0.5, num=env.dim_x))\n",
    "  ax.set_xticklabels([\"%d\" % x for x in np.arange(env.dim_x)])\n",
    "  ax.set_xticks(np.arange(env.dim_x+1), minor=True)\n",
    "  ax.set_yticks(np.linspace(0.5, env.dim_y-0.5, num=env.dim_y))\n",
    "  ax.set_yticklabels([\"%d\" % y for y in np.arange(0, env.dim_y*env.dim_x,\n",
    "                                                  env.dim_x)])\n",
    "  ax.set_yticks(np.arange(env.dim_y+1), minor=True)\n",
    "  ax.grid(which='minor',linestyle='-')\n",
    "\n",
    "\n",
    "def plot_heatmap_max_val(env, value, ax=None):\n",
    "  \"\"\"\n",
    "  Generate heatmap showing maximum value at each state\n",
    "  \"\"\"\n",
    "  if ax is None:\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "  if value.ndim == 1:\n",
    "      value_max = np.reshape(value, (env.dim_y,env.dim_x))\n",
    "  else:\n",
    "      value_max = np.reshape(value.max(axis=1), (env.dim_y,env.dim_x))\n",
    "  value_max = value_max[::-1,:]\n",
    "\n",
    "  im = ax.imshow(value_max, aspect='auto', interpolation='none', cmap='afmhot')\n",
    "  ax.set(title='Maximum value per state')\n",
    "  ax.set_xticks(np.linspace(0, env.dim_x-1, num=env.dim_x))\n",
    "  ax.set_xticklabels([\"%d\" % x for x in np.arange(env.dim_x)])\n",
    "  ax.set_yticks(np.linspace(0, env.dim_y-1, num=env.dim_y))\n",
    "  if env.name != 'windy_cliff_grid':\n",
    "      ax.set_yticklabels(\n",
    "          [\"%d\" % y for y in np.arange(\n",
    "              0, env.dim_y*env.dim_x, env.dim_x)][::-1])\n",
    "  return im\n",
    "\n",
    "\n",
    "def plot_rewards(n_episodes, rewards, average_range=10, ax=None):\n",
    "  \"\"\"\n",
    "  Generate plot showing total reward accumulated in each episode.\n",
    "  \"\"\"\n",
    "  if ax is None:\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "  smoothed_rewards = (conv(rewards, np.ones(average_range), mode='same')\n",
    "                      / average_range)\n",
    "\n",
    "  ax.plot(range(0, n_episodes, average_range),\n",
    "          smoothed_rewards[0:n_episodes:average_range],\n",
    "          marker='o', linestyle='--')\n",
    "  ax.set(xlabel='Episodes', ylabel='Total reward')\n",
    "\n",
    "\n",
    "def plot_performance(env, value, reward_sums):\n",
    "  fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 12))\n",
    "  plot_state_action_values(env, value, ax=axes[0,0])\n",
    "  plot_quiver_max_action(env, value, ax=axes[0,1])\n",
    "  plot_rewards(len(reward_sums), reward_sums, ax=axes[1,0])\n",
    "  im = plot_heatmap_max_val(env, value, ax=axes[1,1])\n",
    "  fig.colorbar(im)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 1: Markov Decision Processes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "**Grid Worlds**\n",
    "\n",
    "As pointed out, bandits only have a single state and immediate rewards for our actions. Many problems we are interested in have multiple states and delayed rewards, i.e. we won't know if the choices we made will pay off over time, or which actions we took contributed to the outcomes we observed.\n",
    "\n",
    "In order to explore these ideas, we turn to a common problem setting: the grid world. Grid worlds are simple environments where each state corresponds to a tile on a 2D grid, and the only actions the agent can take are to move up, down, left, or right across the grid tiles. The agent's job is almost always to find a way to a goal tile in the most direct way possible while overcoming some maze or other obstacles, either static or dynamic.\n",
    "\n",
    "For our discussion we will be looking at the classic Cliff World, or Cliff Walker, environment. This is a 4x10 grid with a starting position in the lower-left and the goal position in the lower-right. Every tile between these two is the \"cliff\": should the agent enter the cliff, it receives a -100 reward, the episode ends, and the next episode begins again at the starting position. Every tile other than the cliff produces a -1 reward when entered. The goal tile ends the episode after taking any action from it.\n",
    "\n",
    "<img alt=\"CliffWorld\" width=\"577\" height=\"308\" src=\"https://github.com/NeuromatchAcademy/course-content/blob/master/tutorials/static/W3D4_Tutorial3_CliffWorld.png?raw=true\">\n",
    "\n",
    "Given these conditions, the maximum achievable reward is -11 (1 up, 9 right, 1 down). Using negative rewards is a common technique to encourage the agent to move and seek out the goal state as fast as possible."
   ]
  },
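  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "The -11 figure can be checked by walking through the reward rules above. The sketch below is a minimal re-implementation of those rules; the helper `route_return` and its move letters (`R`, `U`, `L`, `D`) are hypothetical and separate from the `CliffWorld` class defined later. It follows a string of moves from the start tile and adds up the rewards.\n",
    "\n",
    "```python\n",
    "# Minimal sketch of the Cliff World reward rules (hypothetical helper).\n",
    "# Rows are numbered 0-3 from the bottom, columns 0-9 from the left.\n",
    "MOVES = {'R': (0, 1), 'U': (1, 0), 'L': (0, -1), 'D': (-1, 0)}\n",
    "\n",
    "def route_return(route):\n",
    "  row, col, total = 0, 0, 0  # start tile S in the lower-left corner\n",
    "  for move in route:\n",
    "    d_row, d_col = MOVES[move]\n",
    "    # moving into the border leaves the agent where it is\n",
    "    row = min(max(row + d_row, 0), 3)\n",
    "    col = min(max(col + d_col, 0), 9)\n",
    "    if row == 0 and 1 <= col <= 8:  # bottom-row tiles between S and G are the cliff\n",
    "      return total - 100  # falling ends the episode with a -100 reward\n",
    "    total -= 1  # every other tile, including G, costs -1 when entered\n",
    "    if (row, col) == (0, 9):  # reached G; the final action from G yields 0\n",
    "      return total\n",
    "  return total\n",
    "\n",
    "print(route_return('U' + 'R' * 9 + 'D'))    # -11: hugging the cliff edge\n",
    "print(route_return('UU' + 'R' * 9 + 'DD'))  # -13: one row further from the edge\n",
    "print(route_return('R'))                    # -100: stepping straight into the cliff\n",
    "```\n",
    "\n",
    "The second route reaches the goal just as reliably but costs two extra steps."
   ]
  },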
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 2: Q-Learning\n",
    "\n",
    "Now that we have our environment, how can we solve it? \n",
    "\n",
    "One of the most famous algorithms for estimating action values (aka Q-values) is the temporal difference (TD) **control** algorithm known as *Q-learning* (Watkins, 1989).\n",
    "\n",
    "\\begin{align}\n",
    "Q(s_t,a_t) \\leftarrow Q(s_t,a_t) + \\alpha \\big(r_t + \\gamma\\max_{a} Q(s_{t+1},a) - Q(s_t,a_t)\\big)\n",
    "\\end{align}\n",
    "\n",
    "where $Q(s,a)$ is the value function for action $a$ at state $s$, $\\alpha$ is the learning rate, $r$ is the reward, and $\\gamma$ is the temporal discount rate. When $s_{t+1}$ is terminal (the episode has ended), the term $\\max_{a} Q(s_{t+1},a)$ is taken to be zero.\n",
    "\n",
    "The expression $r_t + \\gamma\\max_{a} Q(s_{t+1},a)$ is referred to as the TD target, while the full expression\n",
    "\\begin{align}\n",
    "r_t + \\gamma\\max_{a} Q(s_{t+1},a) - Q(s_t,a_t),\n",
    "\\end{align}\n",
    "i.e. the difference between the TD target and the current Q-value, is referred to as the TD error, or reward prediction error.\n",
    "\n",
    "Because of the max operator used to select the optimal Q-value in the TD target, Q-learning directly estimates the optimal action value, i.e. the cumulative future reward that would be obtained if the agent behaved optimally, regardless of the policy currently followed by the agent. For this reason, Q-learning is referred to as an **off-policy** method."
   ]
  },
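  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "As a worked example of the update rule, consider a single hypothetical transition with made-up Q-values (not taken from an actual Cliff World run). The learning rate $\\alpha = 0.1$ and discount $\\gamma = 1$ match the `params` used later in this tutorial. Note that the TD target uses the largest next-state value, whatever action the agent actually takes next.\n",
    "\n",
    "```python\n",
    "# One hypothetical transition (s_t, a_t, r_t, s_{t+1}); all numbers are made up\n",
    "alpha, gamma = 0.1, 1.0\n",
    "q_sa = 0.2                      # current estimate Q(s_t, a_t)\n",
    "reward = -1.0                   # r_t for entering an ordinary tile\n",
    "q_next = [0.5, 0.8, -2.0, 0.3]  # Q(s_{t+1}, a) for actions R, U, L, D\n",
    "\n",
    "td_target = reward + gamma * max(q_next)  # -1 + 0.8 = -0.2\n",
    "td_error = td_target - q_sa               # -0.2 - 0.2 = -0.4\n",
    "q_sa_new = q_sa + alpha * td_error        # 0.2 + 0.1 * (-0.4) = 0.16\n",
    "\n",
    "print(round(td_target, 3), round(td_error, 3), round(q_sa_new, 3))\n",
    "```\n",
    "\n",
    "If $s_{t+1}$ is terminal, the `max(q_next)` term is replaced by 0, since no further reward follows."
   ]
  },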
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Coding Exercise 2: Implement the Q-learning algorithm\n",
    "\n",
    "In this exercise you will implement the Q-learning update rule described above. It takes in as arguments the previous state $s_t$, the action $a_t$ taken, the reward received $r_t$, the current state $s_{t+1}$, the Q-value table, and a dictionary of parameters that contain the learning rate $\\alpha$ and discount factor $\\gamma$. The method returns the updated Q-value table. For the parameter dictionary, $\\alpha$: `params['alpha']` and $\\gamma$: `params['gamma']`. \n",
    "\n",
    "Once we have our Q-learning algorithm, we will see how it handles learning to solve the Cliff World environment. \n",
    "\n",
    "You will recall from the previous tutorial that a major part of any reinforcement learning algorithm is its ability to balance exploitation and exploration. For our Q-learning agent, we again turn to the epsilon-greedy strategy. At each step, with probability $1 - \\epsilon$ the agent chooses the action with the highest estimated value in its current state; otherwise (with probability $\\epsilon$) it chooses an action uniformly at random.\n",
    "\n",
    "The process by which our agent interacts with and learns about the environment is handled for you in the helper function `learn_environment`. This implements the entire learning episode lifecycle of stepping through the state observation, action selection (epsilon-greedy) and execution, reward, and state transition. Feel free to review that code later to see how it all fits together, but for now let's test out our agent."
   ]
  },
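  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "One detail of epsilon-greedy is easy to miss: the random branch picks among *all* actions, including the greedy one, so the greedy action is chosen with probability $1 - \\epsilon + \\epsilon/|A|$ rather than exactly $1 - \\epsilon$. The sketch below checks this empirically. It follows the same rule as the `epsilon_greedy` helper defined below, but under a hypothetical name, `epsilon_greedy_rng`, and with an extra `rng` argument (a NumPy random `Generator`) that we assume here for reproducibility.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def epsilon_greedy_rng(q, epsilon, rng):\n",
    "  # exploit with probability 1 - epsilon ...\n",
    "  if rng.random() > epsilon:\n",
    "    return int(np.argmax(q))\n",
    "  # ... otherwise pick uniformly among ALL actions (the greedy one included)\n",
    "  return int(rng.integers(len(q)))\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "q = np.array([0.0, 2.0, -1.0, 0.5])  # action 1 has the highest value\n",
    "epsilon = 0.1\n",
    "choices = np.array([epsilon_greedy_rng(q, epsilon, rng) for _ in range(100_000)])\n",
    "freq_greedy = np.mean(choices == 1)\n",
    "expected = 1 - epsilon + epsilon / len(q)  # 0.9 + 0.025 = 0.925\n",
    "print(freq_greedy, expected)\n",
    "```\n",
    "\n",
    "With four actions and $\\epsilon = 0.1$, each non-greedy action is tried about 2.5% of the time."
   ]
  },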
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute to get helper functions `epsilon_greedy`, `CliffWorld`, and `learn_environment`\n",
    "\n",
    "def epsilon_greedy(q, epsilon):\n",
    "  \"\"\"Epsilon-greedy policy: selects the maximum value action with probability\n",
    "  (1-epsilon) and selects randomly with epsilon probability.\n",
    "\n",
    "  Args:\n",
    "    q (ndarray): an array of action values\n",
    "    epsilon (float): probability of selecting an action randomly\n",
    "\n",
    "  Returns:\n",
    "    int: the chosen action\n",
    "  \"\"\"\n",
    "  if np.random.random() > epsilon:\n",
    "    action = np.argmax(q)\n",
    "  else:\n",
    "    action = np.random.choice(len(q))\n",
    "\n",
    "  return action\n",
    "\n",
    "\n",
    "class CliffWorld:\n",
    "  \"\"\"\n",
    "  World: Cliff world.\n",
    "  40 states (4-by-10 grid world).\n",
    "  The mapping from state to the grids are as follows:\n",
    "  30 31 32 ... 39\n",
    "  20 21 22 ... 29\n",
    "  10 11 12 ... 19\n",
    "  0  1  2  ...  9\n",
    "  0 is the starting state (S) and 9 is the goal state (G).\n",
    "  Actions 0, 1, 2, 3 correspond to right, up, left, down.\n",
    "  Moving anywhere from state 9 (goal state) will end the session.\n",
    "  Moving right from state 0, or down from states 11-18, falls into the cliff:\n",
    "      this incurs a reward of -100 and ends the episode.\n",
    "  Landing in any states other than the goal state will incur a reward of -1.\n",
    "  Going towards the border when already at the border will stay in the same\n",
    "      place.\n",
    "  \"\"\"\n",
    "  def __init__(self):\n",
    "    self.name = \"cliff_world\"\n",
    "    self.n_states = 40\n",
    "    self.n_actions = 4\n",
    "    self.dim_x = 10\n",
    "    self.dim_y = 4\n",
    "    self.init_state = 0\n",
    "\n",
    "  def get_outcome(self, state, action):\n",
    "    if state == 9:  # goal state\n",
    "      reward = 0\n",
    "      next_state = None\n",
    "      return next_state, reward\n",
    "    reward = -1  # default reward value\n",
    "    if action == 0:  # move right\n",
    "      next_state = state + 1\n",
    "      if state % 10 == 9:  # right border\n",
    "        next_state = state\n",
    "      elif state == 0:  # start state (next state is cliff)\n",
    "        next_state = None\n",
    "        reward = -100\n",
    "    elif action == 1:  # move up\n",
    "      next_state = state + 10\n",
    "      if state >= 30:  # top border\n",
    "        next_state = state\n",
    "    elif action == 2:  # move left\n",
    "      next_state = state - 1\n",
    "      if state % 10 == 0:  # left border\n",
    "        next_state = state\n",
    "    elif action == 3:  # move down\n",
    "      next_state = state - 10\n",
    "      if state >= 11 and state <= 18:  # next is cliff\n",
    "        next_state = None\n",
    "        reward = -100\n",
    "      elif state <= 9:  # bottom border\n",
    "        next_state = state\n",
    "    else:\n",
    "      print(\"Action must be between 0 and 3.\")\n",
    "      next_state = None\n",
    "      reward = None\n",
    "    return int(next_state) if next_state is not None else None, reward\n",
    "\n",
    "  def get_all_outcomes(self):\n",
    "    outcomes = {}\n",
    "    for state in range(self.n_states):\n",
    "      for action in range(self.n_actions):\n",
    "        next_state, reward = self.get_outcome(state, action)\n",
    "        outcomes[state, action] = [(1, next_state, reward)]\n",
    "    return outcomes\n",
    "\n",
    "\n",
    "def learn_environment(env, learning_rule, params, max_steps, n_episodes):\n",
    "  # Start with a uniform value function\n",
    "  value = np.ones((env.n_states, env.n_actions))\n",
    "\n",
    "  # Run learning\n",
    "  reward_sums = np.zeros(n_episodes)\n",
    "\n",
    "  # Loop over episodes\n",
    "  for episode in range(n_episodes):\n",
    "    state = env.init_state  # initialize state\n",
    "    reward_sum = 0\n",
    "\n",
    "    for t in range(max_steps):\n",
    "      # choose next action\n",
    "      action = epsilon_greedy(value[state], params['epsilon'])\n",
    "\n",
    "      # observe outcome of action on environment\n",
    "      next_state, reward = env.get_outcome(state, action)\n",
    "\n",
    "      # update value function\n",
    "      value = learning_rule(state, action, reward, next_state, value, params)\n",
    "\n",
    "      # sum rewards obtained\n",
    "      reward_sum += reward\n",
    "\n",
    "      if next_state is None:\n",
    "          break  # episode ends\n",
    "      state = next_state\n",
    "\n",
    "    reward_sums[episode] = reward_sum\n",
    "\n",
    "  return value, reward_sums"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def q_learning(state, action, reward, next_state, value, params):\n",
    "  \"\"\"Q-learning: updates the value function and returns it.\n",
    "\n",
    "  Args:\n",
    "    state (int): the current state identifier\n",
    "    action (int): the action taken\n",
    "    reward (float): the reward received\n",
    "    next_state (int): the transitioned to state identifier\n",
    "    value (ndarray): current value function of shape (n_states, n_actions)\n",
    "    params (dict): a dictionary containing the default parameters\n",
    "\n",
    "  Returns:\n",
    "    ndarray: the updated value function of shape (n_states, n_actions)\n",
    "  \"\"\"\n",
    "  # Q-value of current state-action pair\n",
    "  q = value[state, action]\n",
    "\n",
    "  ##########################################################\n",
    "  ## TODO for students: implement the Q-learning update rule\n",
    "  # Fill out function and remove\n",
    "  raise NotImplementedError(\"Student exercise: implement the Q-learning update rule\")\n",
    "  ##########################################################\n",
    "\n",
    "  # write an expression for finding the maximum Q-value at the next state\n",
    "  if next_state is None:\n",
    "    max_next_q = 0\n",
    "  else:\n",
    "    max_next_q = ...\n",
    "\n",
    "  # write the expression to compute the TD error\n",
    "  td_error = ...\n",
    "  # write the expression that updates the Q-value for the state-action pair\n",
    "  value[state, action] = ...\n",
    "\n",
    "  return value\n",
    "\n",
    "\n",
    "# set for reproducibility, comment out / change seed value for different results\n",
    "np.random.seed(1)\n",
    "\n",
    "# parameters needed by our policy and learning rule\n",
    "params = {\n",
    "  'epsilon': 0.1,  # epsilon-greedy policy\n",
    "  'alpha': 0.1,  # learning rate\n",
    "  'gamma': 1.0,  # discount factor\n",
    "}\n",
    "\n",
    "# episodes/trials\n",
    "n_episodes = 500\n",
    "max_steps = 1000\n",
    "\n",
    "# environment initialization\n",
    "env = CliffWorld()\n",
    "\n",
    "# solve Cliff World using Q-learning\n",
    "results = learn_environment(env, q_learning, params, max_steps, n_episodes)\n",
    "value_qlearning, reward_sums_qlearning = results\n",
    "\n",
    "# Plot results\n",
    "plot_performance(env, value_qlearning, reward_sums_qlearning)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "[*Click for solution*](https://github.com/NeoNeuron/professional-workshop-3/tree/master//tutorials/W10_ReinforcementLearning/solutions/W10_Tutorial3_Solution_6573228d.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=2267.0 height=1691.0 src=https://raw.githubusercontent.com/NeoNeuron/professional-workshop-3/master/tutorials/W10_ReinforcementLearning/static/W10_Tutorial3_Solution_6573228d_0.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "If all went well, we should see four plots that show different aspects of our agent's learning and progress.\n",
    "\n",
    "* The top left is a representation of the Q-table itself, showing the values for different actions in different states. Notably, going right from the starting state or down when above the cliff is clearly very bad.\n",
    "* The top right figure shows the greedy policy based on the Q-table, i.e. what action would the agent take if it only took its best guess in that state.\n",
    "* The bottom right is similar to the top right, only instead of showing the best action, it shows the maximum Q-value at each state.\n",
    "* The bottom left is the actual proof of learning: the total reward per episode (averaged over blocks of 10 episodes) rises steadily and levels off near the maximum possible reward of -11. Occasional dips remain because the epsilon-greedy policy keeps exploring, sometimes into the cliff.\n",
    "\n",
    "Feel free to try changing the parameters or random seed and see how the agent's behavior changes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Summary\n",
    "\n",
    "In this tutorial you implemented a reinforcement learning agent based on Q-learning to solve the Cliff World environment. Q-learning combined the epsilon-greedy approach to exploration-exploitation with a table-based value function to learn the expected future reward of each action in each state."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Bonus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "## Bonus Section 1: SARSA\n",
    "\n",
    "An alternative to Q-learning, the SARSA algorithm also estimates action values. However, rather than estimating the optimal (off-policy) values, SARSA estimates the **on-policy** action value, i.e. the cumulative future reward that would be obtained if the agent kept following its current policy, exploratory choices included.\n",
    "\n",
    "\\begin{align}\n",
    "Q(s_t,a_t) \\leftarrow Q(s_t,a_t) + \\alpha \\big(r_t + \\gamma Q(s_{t+1},a_{t+1}) - Q(s_t,a_t)\\big)\n",
    "\\end{align}\n",
    "\n",
    "where, once again, $Q(s,a)$ is the value function for action $a$ at state $s$, $\\alpha$ is the learning rate, $r$ is the reward, and $\\gamma$ is the temporal discount rate.\n",
    "\n",
    "In fact, you will notice that the *only* difference between Q-learning and SARSA is that the TD target uses an action selected by the policy (in our case epsilon-greedy) rather than the action that maximizes the Q-value."
   ]
  },
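  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "To see how the two targets differ, take one hypothetical next state on the cliff edge where moving down (action 3) is disastrous; the Q-values below are made up for illustration. Q-learning bootstraps from the best next action only. SARSA bootstraps from a sampled epsilon-greedy action, so on average its target weights each next action by its policy probability; computing that average directly is the variant known as Expected SARSA. Note that the `sarsa` function in this notebook samples its next action inside the update, separately from the action `learn_environment` then executes; both are drawn from the same epsilon-greedy policy.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "epsilon, gamma, reward = 0.1, 1.0, -1.0\n",
    "# Hypothetical Q(s_{t+1}, a) for actions R, U, L, D at a tile on the cliff edge\n",
    "q_next = np.array([-10.0, -12.0, -12.0, -100.0])\n",
    "\n",
    "# Q-learning: bootstrap from the best next action\n",
    "q_learning_target = reward + gamma * q_next.max()\n",
    "\n",
    "# epsilon-greedy probabilities: epsilon/|A| for every action, plus 1 - epsilon for the greedy one\n",
    "probs = np.full(len(q_next), epsilon / len(q_next))\n",
    "probs[np.argmax(q_next)] += 1 - epsilon\n",
    "\n",
    "# SARSA: one sampled next action per update ...\n",
    "rng = np.random.default_rng(0)\n",
    "sampled_action = rng.choice(len(q_next), p=probs)\n",
    "sarsa_target = reward + gamma * q_next[sampled_action]\n",
    "# ... whose average over the policy is the Expected SARSA target\n",
    "expected_sarsa_target = reward + gamma * probs @ q_next\n",
    "\n",
    "print(q_learning_target, sarsa_target, expected_sarsa_target)\n",
    "```\n",
    "\n",
    "Only the SARSA-style targets are affected by the -100 entry, because the epsilon-greedy policy occasionally selects that action."
   ]
  },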
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Bonus Coding Exercise 1: Implement the SARSA algorithm\n",
    "\n",
    "In this exercise you will implement the SARSA update rule described above. Just like Q-learning, it takes in as arguments the previous state $s_t$, the action $a_t$ taken, the reward received $r_t$, the current state $s_{t+1}$, the Q-value table, and a dictionary of parameters that contain the learning rate $\\alpha$ and discount factor $\\gamma$. The method returns the updated Q-value table. You may use the `epsilon_greedy` function to acquire the next action. For the parameter dictionary, $\\alpha$: `params['alpha']`, $\\gamma$: `params['gamma']`, and $\\epsilon$: `params['epsilon']`. \n",
    "\n",
    "Once we have an implementation for SARSA, we will see how it tackles Cliff World. We will again use the same setup we tried with Q-learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def sarsa(state, action, reward, next_state, value, params):\n",
    "  \"\"\"SARSA: updates the value function and returns it.\n",
    "\n",
    "  Args:\n",
    "    state (int): the current state identifier\n",
    "    action (int): the action taken\n",
    "    reward (float): the reward received\n",
    "    next_state (int): the transitioned to state identifier\n",
    "    value (ndarray): current value function of shape (n_states, n_actions)\n",
    "    params (dict): a dictionary containing the default parameters\n",
    "\n",
    "  Returns:\n",
    "    ndarray: the updated value function of shape (n_states, n_actions)\n",
    "  \"\"\"\n",
    "  # value of previous state-action pair\n",
    "  q = value[state, action]\n",
    "\n",
    "  ##########################################################\n",
    "  ## TODO for students: implement the SARSA update rule\n",
    "  # Fill out function and remove\n",
    "  raise NotImplementedError(\"Student exercise: implement the SARSA update rule\")\n",
    "  ##########################################################\n",
    "\n",
    "  # estimate the value of the next state under our policy by sampling an\n",
    "  # action from it (a terminal next state has value 0)\n",
    "  if next_state is None:\n",
    "    policy_next_q = 0\n",
    "  else:\n",
    "    # write an expression for selecting an action using epsilon-greedy\n",
    "    policy_action = ...\n",
    "    # write an expression for obtaining the value of the policy action at the\n",
    "    # next state\n",
    "    policy_next_q = ...\n",
    "\n",
    "  # write the expression to compute the TD error\n",
    "  td_error = ...\n",
    "  # write the expression that updates the Q-value for the state-action pair\n",
    "  value[state, action] = ...\n",
    "\n",
    "  return value\n",
    "\n",
    "\n",
    "# set for reproducibility, comment out / change seed value for different results\n",
    "np.random.seed(1)\n",
    "\n",
    "# parameters needed by our policy and learning rule\n",
    "params = {\n",
    "  'epsilon': 0.1,  # epsilon-greedy policy\n",
    "  'alpha': 0.1,  # learning rate\n",
    "  'gamma': 1.0,  # discount factor\n",
    "}\n",
    "\n",
    "# episodes/trials\n",
    "n_episodes = 500\n",
    "max_steps = 1000\n",
    "\n",
    "# environment initialization\n",
    "env = CliffWorld()\n",
    "\n",
    "# learn Cliff World using Sarsa\n",
    "results = learn_environment(env, sarsa, params, max_steps, n_episodes)\n",
    "value_sarsa, reward_sums_sarsa = results\n",
    "\n",
    "# Plot results\n",
    "plot_performance(env, value_sarsa, reward_sums_sarsa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def sarsa(state, action, reward, next_state, value, params):\n",
    "  \"\"\"SARSA: updates the value function and returns it.\n",
    "\n",
    "  Args:\n",
    "    state (int): the current state identifier\n",
    "    action (int): the action taken\n",
    "    reward (float): the reward received\n",
    "    next_state (int): the transitioned to state identifier\n",
    "    value (ndarray): current value function of shape (n_states, n_actions)\n",
    "    params (dict): a dictionary containing the default parameters\n",
    "\n",
    "  Returns:\n",
    "    ndarray: the updated value function of shape (n_states, n_actions)\n",
    "  \"\"\"\n",
    "  # value of previous state-action pair\n",
    "  q = value[state, action]\n",
    "\n",
    "  # estimate the value of the next state under our policy by sampling an\n",
    "  # action from it (a terminal next state has value 0)\n",
    "  if next_state is None:\n",
    "    policy_next_q = 0\n",
    "  else:\n",
    "    # write an expression for selecting an action using epsilon-greedy\n",
    "    policy_action = epsilon_greedy(value[next_state], params['epsilon'])\n",
    "    # write an expression for obtaining the value of the policy action at the\n",
    "    # next state\n",
    "    policy_next_q = value[next_state, policy_action]\n",
    "\n",
    "  # write the expression to compute the TD error\n",
    "  td_error = reward + params['gamma'] * policy_next_q - q\n",
    "  # write the expression that updates the Q-value for the state-action pair\n",
    "  value[state, action] = q + params['alpha'] * td_error\n",
    "\n",
    "  return value\n",
    "\n",
    "\n",
    "# set for reproducibility, comment out / change seed value for different results\n",
    "np.random.seed(1)\n",
    "\n",
    "# parameters needed by our policy and learning rule\n",
    "params = {\n",
    "  'epsilon': 0.1,  # epsilon-greedy policy\n",
    "  'alpha': 0.1,  # learning rate\n",
    "  'gamma': 1.0,  # discount factor\n",
    "}\n",
    "\n",
    "# episodes/trials\n",
    "n_episodes = 500\n",
    "max_steps = 1000\n",
    "\n",
    "# environment initialization\n",
    "env = CliffWorld()\n",
    "\n",
    "# learn Cliff World using Sarsa\n",
    "results = learn_environment(env, sarsa, params, max_steps, n_episodes)\n",
    "value_sarsa, reward_sums_sarsa = results\n",
    "\n",
    "# Plot results\n",
    "plot_performance(env, value_sarsa, reward_sums_sarsa)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "We should see that SARSA also solves the task, with outcomes similar to those of Q-learning. One notable difference is that SARSA seems to be skittish around the cliff edge and often moves further away from it before coming back down to the goal.\n",
    "\n",
    "Again, feel free to try changing the parameters or random seed and see how the agent's behavior changes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "## Bonus Section 2: On-Policy vs Off-Policy\n",
    "   \n",
    "We have now seen an example of both on- and off-policy learning algorithms. Let's compare both Q-learning and SARSA reward results again, side-by-side, to see how they stack up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute to see visualization\n",
    "\n",
    "# parameters needed by our policy and learning rule\n",
    "params = {\n",
    "  'epsilon': 0.1,  # epsilon-greedy policy\n",
    "  'alpha': 0.1,  # learning rate\n",
    "  'gamma': 1.0,  # discount factor\n",
    "}\n",
    "\n",
    "# episodes/trials\n",
    "n_episodes = 500\n",
    "max_steps = 1000\n",
    "\n",
    "# environment initialization\n",
    "env = CliffWorld()\n",
    "\n",
    "# learn Cliff World with Q-learning and with SARSA, resetting the seed before each run\n",
    "np.random.seed(1)\n",
    "results = learn_environment(env, q_learning, params, max_steps, n_episodes)\n",
    "value_qlearning, reward_sums_qlearning = results\n",
    "np.random.seed(1)\n",
    "results = learn_environment(env, sarsa, params, max_steps, n_episodes)\n",
    "value_sarsa, reward_sums_sarsa = results\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(reward_sums_qlearning, label='Q-learning')\n",
    "ax.plot(reward_sums_sarsa, label='SARSA')\n",
    "ax.set(xlabel='Episodes', ylabel='Total reward')\n",
    "plt.legend(loc='lower right');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "On this simple Cliff World task, Q-learning and SARSA are almost indistinguishable from a performance standpoint, but we can see that Q-learning has a slight edge within the 500-episode time horizon. Let's look at the illustrated \"greedy policy\" plots again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute to see visualization\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 6))\n",
    "plot_quiver_max_action(env, value_qlearning, ax=ax1)\n",
    "ax1.set(title='Q-learning maximum value/probability actions')\n",
    "plot_quiver_max_action(env, value_sarsa, ax=ax2)\n",
    "ax2.set(title='SARSA maximum value/probability actions');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "What should immediately jump out is that Q-learning learned to go up, then immediately go to the right, skirting the cliff edge, until it hits the wall and goes down to the goal. The policy further away from the cliff is less certain.\n",
    "\n",
    "SARSA, on the other hand, appears to avoid the cliff edge, going up one more tile before heading over to the goal side. This also clearly solves the challenge of getting to the goal, but at an extra cost of -2 compared to the truly optimal route (-13 instead of -11).\n",
    "\n",
    "Why do you think these behaviors emerged the way they did?"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W10_Tutorial3",
   "provenance": [],
   "toc_visible": true
  },
  "interpreter": {
   "hash": "9516f62da91337f10c2adbe814d9c63a4b08f8271333386358218606edb781e3"
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3.7.11 64-bit ('pw3': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
